import json
import re

from utils_for_llm import *
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset, DatasetDict, Dataset
import warnings
from accelerate import Accelerator
from accelerate.utils import gather_object
from codebleu import calc_codebleu
import os
import torch.distributed as dist
from datetime import timedelta
import time
from transformers import DataCollatorWithPadding

# Silence every Python warning emitted during this run.
warnings.filterwarnings("ignore")

# NCCL error-handling and tokenizers settings. These must be in the
# environment before the process group is created.
os.environ.update({
    'TORCH_NCCL_BLOCKING_WAIT': '1',
    'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
    'TOKENIZERS_PARALLELISM': "False",
})

# Explicit NCCL process-group setup (6 h timeout for long generations) is
# skipped when the script runs under PyCharm, i.e. PYCHARM_HOSTED == '1'.
running_in_pycharm = os.getenv('PYCHARM_HOSTED') == '1'
if not running_in_pycharm:
    dist.init_process_group(backend='nccl', timeout=timedelta(hours=6))

# Pre-built prompts, indexed by split name (presumably a dict of DataFrames --
# confirm). pickle can execute arbitrary code: only load files you trust.
with open('./prompt_for_API_ICL.pkl', 'rb') as fp:
    prompt_for_API_ICL = pickle.load(fp)



accelerator = Accelerator(mixed_precision='bf16')

# Determine the configured DeepSpeed ZeRO stage. The value is -1 when no
# DeepSpeed plugin is active, and None when the config has no "stage" entry.
if not accelerator.state.deepspeed_plugin:
    zero_version = -1
else:
    deepspeed_config = accelerator.state.deepspeed_plugin.deepspeed_config
    zero_optimization = deepspeed_config.get('zero_optimization', {})
    zero_version = zero_optimization.get("stage")
    print(zero_version)




def predict_on_validation_BATCH(model, tokenizer, eval_dataset, batch_size=1):
    """Generate a model response for every sample in ``eval_dataset``.

    For each sample:

    1. Its ``'prompt'`` is rendered with the tokenizer's chat template.
    2. The rendered text is passed to ``model.generate``.
    3. The decoded continuation, with prompt tokens removed, is written back
       onto the same sample dict under ``'Llama3.1-8B_response'``.

    Args:
        model: Causal LM used through ``model.generate``. It is put in eval
            mode for inference and switched back to train mode afterwards.
        tokenizer: Tokenizer providing ``apply_chat_template``. Its
            ``pad_token_id`` is forwarded to ``generate``.
        eval_dataset: Iterable of dict samples. Each must have a ``'prompt'``
            key accepted by ``apply_chat_template`` (presumably a list of
            ``{"role", "content"}`` messages -- TODO confirm). Samples are
            mutated in place.
        batch_size: Currently unused; samples are generated one at a time.
            Kept for interface compatibility.

    Returns:
        ``eval_dataset`` itself with responses filled in, so callers can hand
        it directly to ``gather_object``.
    """
    model.eval()

    for sample in tqdm(eval_dataset):
        # Render to text (not token ids) and append the assistant header, so
        # the model writes an answer instead of extending the user turn.
        prompt_text = tokenizer.apply_chat_template(
            sample['prompt'], tokenize=False, add_generation_prompt=True
        )

        # The rendered template presumably already contains BOS, so no extra
        # special tokens are added. Tensors stay local rather than being
        # stored on the sample, which keeps the sample JSON-serialisable.
        # NOTE(review): truncation follows tokenizer.truncation_side (right by
        # default), which would drop the assistant header from over-long
        # prompts -- confirm prompts fit in 8192 tokens.
        inputs = tokenizer(
            prompt_text,
            max_length=8192,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt',
        ).to(accelerator.device)

        with torch.inference_mode():
            # NOTE(review): max_length bounds prompt + generated tokens, so a
            # prompt close to 8192 tokens leaves little room to generate.
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                pad_token_id=tokenizer.pad_token_id,
                max_length=8192,
                num_return_sequences=1,
                # The model may be loaded with use_cache=False. The KV cache
                # avoids re-encoding the whole prefix at every decoding step.
                use_cache=True,
            )

        # generate() returns prompt + continuation per row. Take the single
        # row and drop the prompt tokens (sequence length is dim 1).
        prompt_len = inputs['input_ids'].shape[1]
        sample['Llama3.1-8B_response'] = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

    model.train()
    return eval_dataset

# Barrier: every process waits here before loading the model.
accelerator.wait_for_everyone()

model_path = "/Pretrained_Language_Models/Meta-Llama-3.1-8B-Instruct"

# Each process loads a full bf16 copy of the model onto the device whose
# index equals its accelerate process index.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map={"": accelerator.process_index},
    attn_implementation="flash_attention_2",
    use_cache=False,
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Pad on the left so every prompt ends right where generation starts.
tokenizer.padding_side = "left"
# Reuse EOS as the pad token (presumably the tokenizer ships without one).
tokenizer.pad_token = tokenizer.eos_token


for split in ['val', 'ood']:
    # Restart the timer so each split reports only its own inference time.
    start = time.time()

    # Convert this split's table into a list of per-sample dicts.
    prompt_for_API_ICL[split] = prompt_for_API_ICL[split].to_dict(orient='records')

    # Each process runs inference on its own shard of the samples.
    with accelerator.split_between_processes(prompt_for_API_ICL[split]) as eval_dataset:
        infer_result = predict_on_validation_BATCH(model, tokenizer, eval_dataset, batch_size=1)
        if infer_result is None:
            # A predictor that only annotates samples in place returns nothing;
            # in that case the shard itself holds the results.
            infer_result = eval_dataset

    # Collect every process's annotated samples into one list.
    infer_result = gather_object(infer_result)
    timediff = time.time() - start

    # Convert the elapsed time into hours, minutes and seconds.
    minutes, seconds = divmod(timediff, 60)
    hours, minutes = divmod(minutes, 60)

    # Only the main process writes and reports results.
    if accelerator.is_main_process:
        # 'val' keeps its original "DEV" tag. Other splits get their own tag
        # so they no longer overwrite the DEV output file.
        tag = 'DEV' if split == 'val' else split.upper()
        # NOTE(review): `path` is not defined in the visible code; presumably
        # it comes from `from utils_for_llm import *` -- confirm, otherwise
        # this raises NameError only after inference has finished.
        out_file = os.path.join('./output', f'{path}_{tag}_final.json')
        os.makedirs('./output', exist_ok=True)

        print('=' * 25, tag, '=' * 25)
        print(f"Inference Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
        with open(out_file, 'w') as fp:
            json.dump(infer_result, fp, indent=4)
        print('Final result is dumped to:\n ', out_file)
        print_result(infer_result)
        print('=' * 50)